{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[ 2.55025 , -1.274323],\n",
       "        [ 2.345839, -0.398576],\n",
       "        [ 1.961587, -1.238267],\n",
       "        [ 1.56649 , -3.583607],\n",
       "        [-0.150944, -1.745203]]),\n",
       " array([-24.897896, -23.759634, -26.591775, -32.467744, -33.943239]))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import sympy as sp\n",
    "\n",
    "\n",
    "# Load the data\n",
    "def load_data(path='多变量线性数据.txt'):\n",
    "    \"\"\"Read a comma-separated text file into features and targets.\n",
    "\n",
    "    Each non-blank line must contain at least three comma-separated\n",
    "    numbers: the first two are stored as features, the third as target.\n",
    "    Blank lines are skipped.\n",
    "\n",
    "    Args:\n",
    "        path: File to read; defaults to the data file this notebook uses.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (x, y): x is a float array of shape (n, 2) and y a float\n",
    "        array of shape (n,), where n is the number of non-blank lines.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a field cannot be parsed as a float.\n",
    "        IndexError: If a non-blank line has fewer than three fields.\n",
    "    \"\"\"\n",
    "    with open(path) as fr:\n",
    "        # Skip blank lines (e.g. an empty line at the end of the file);\n",
    "        # they would otherwise fail to parse as numbers.\n",
    "        rows = [line.strip().split(',') for line in fr if line.strip()]\n",
    "\n",
    "    x = np.empty((len(rows), 2), dtype=float)\n",
    "    y = np.empty(len(rows), dtype=float)\n",
    "\n",
    "    for i, fields in enumerate(rows):\n",
    "        # Convert explicitly instead of relying on numpy's str -> float cast\n",
    "        x[i] = [float(value) for value in fields[:2]]\n",
    "        y[i] = float(fields[2])\n",
    "\n",
    "    return x, y\n",
    "\n",
    "\n",
    "# Load the dataset; later cells read x (features) and y (targets).\n",
    "x, y = load_data()\n",
    "# Show the first five samples as a sanity check.\n",
    "(x[:5], y[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constants: N samples and M features, read off the data matrix\n",
    "N = len(x)\n",
    "M = x.shape[1]\n",
    "\n",
    "# Symbolic model parameters and their starting numeric values\n",
    "w0, w1, b = sp.symbols('w0 w1 b')\n",
    "subs = dict(zip((w0, w1, b), (1.0, 1.0, 0.0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 200.0 b^{2} + 782.575478 b w_{0} - 788.560048 b w_{1} + 11229.393686 b + 971.882447630727 w_{0}^{2} - 1462.63761572567 w_{0} w_{1} + 20571.2449100403 w_{0} + 1007.65311719167 w_{1}^{2} - 23299.5011157771 w_{1} + 160883.539418344$"
      ],
      "text/plain": [
       "200.0*b**2 + 782.575478*b*w0 - 788.560048*b*w1 + 11229.393686*b + 971.882447630727*w0**2 - 1462.63761572567*w0*w1 + 20571.2449100403*w0 + 1007.65311719167*w1**2 - 23299.5011157771*w1 + 160883.539418344"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss: sum of squared differences between prediction and target\n",
    "loss = sum((w0 * x0 + w1 * x1 + b - target)**2\n",
    "           for (x0, x1), target in zip(x, y))\n",
    "# Simplify the expression\n",
    "loss = sp.cancel(loss)\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 782.575478 b + 1943.76489526145 w_{0} - 1462.63761572567 w_{1} + 20571.2449100403$"
      ],
      "text/plain": [
       "782.575478*b + 1943.76489526145*w0 - 1462.63761572567*w1 + 20571.2449100403"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Partial derivatives of the loss with respect to each parameter\n",
    "d_w0, d_w1, d_b = (loss.diff(param) for param in (w0, w1, b))\n",
    "\n",
    "# Display one of them as a check\n",
    "d_w0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(21052.372189576054, -22746.832497119423, 11223.409115999997)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the derivatives\n",
    "def get_subs():\n",
    "    \"\"\"Evaluate the loss gradient at the current parameter values.\n",
    "\n",
    "    Substitutes the module-level ``subs`` mapping into ``d_w0``, ``d_w1``\n",
    "    and ``d_b``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of three floats: the derivatives w.r.t. w0, w1 and b.\n",
    "    \"\"\"\n",
    "    return tuple(float(grad.subs(subs)) for grad in (d_w0, d_w1, d_b))\n",
    "\n",
    "\n",
    "get_subs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16985.9137646605\n",
      "89.6561038706041\n",
      "1.71257378872724\n",
      "0.0327188836781716\n",
      "0.000625097688498499\n",
      "1.19426995297545e-5\n",
      "2.28354110731743e-7\n",
      "4.54383553005755e-9\n",
      "2.72848410531878e-10\n",
      "1.29148247651756e-10\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{w0: 3.0000000436287655, w1: 2.0000000142326924, b: -30.000000115413393}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gradient descent: step w0, w1 and b against their gradients\n",
    "for step in range(1000):\n",
    "    # All gradients are evaluated before any parameter is updated;\n",
    "    # b uses a larger step size than the weights.\n",
    "    for param, grad, rate in zip((w0, w1, b), get_subs(), (1e-4, 1e-4, 1e-3)):\n",
    "        subs[param] -= grad * rate\n",
    "\n",
    "    # Print the loss every 100 steps\n",
    "    if step % 100 == 0:\n",
    "        print(loss.subs(subs))\n",
    "\n",
    "subs"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
